import numpy as np
import gym
import zipfile
import scipy, scipy.misc
from urllib.request import urlretrieve

import scipy.integrate

import os, sys
# Short alias for the SciPy ODE solver; `get_orbit` calls it by this name.
solve_ivp = scipy.integrate.solve_ivp
# Append the parent of this file's directory to the import path, presumably so
# that the `simulators` package imported below can be resolved.
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

from simulators.utils import to_pickle, from_pickle
from skimage.transform import resize
import pdb


def get_theta(obs):
    '''Converts a pendulum observation into a single angle.

    The angle is `arctan2(obs[0], -obs[1])` shifted by pi/2 and then wrapped
    back into the interval [-pi, pi]. Only the first two entries of `obs`
    are read.
    '''
    angle = np.arctan2(obs[0], -obs[1]) + np.pi/2
    # wrap into [-pi, pi]
    if angle < -np.pi:
        angle = angle + 2*np.pi
    if angle > np.pi:
        angle = angle - 2*np.pi
    return angle


def preproc(X, side):
    '''Turns an rgb frame into a small single-channel image.

    Takes channel 0 minus channel 1 of `X`, cropped to rows 240:-120 and
    columns 120:-120, resizes the result to `[int(side/2), side]` and divides
    it by 255.
    '''
    # NOTE(review): if X is an unsigned-integer image the subtraction below
    # wraps around instead of going negative -- confirm this is intended.
    rows, cols = slice(240, -120), slice(120, -120)
    first_channel = X[..., 0][rows, cols]
    second_channel = X[..., 1][rows, cols]
    difference = first_channel - second_channel
    target_shape = [int(side/2), side]
    return resize(difference, target_shape) / 255.


# def preproc(X, side):
#     '''Crops, downsamples, desaturates, etc. the rgb pendulum observation.'''
#     pdb.set_trace()
#     X = X[...,0][440:-220,330:-330] - X[...,1][440:-220,330:-330]
#     return resize(X, [int(side), side]) / 255.

def sample_gym(seed=0, timesteps=103, trials=200, side=28, min_angle=0., max_angle=np.pi/6,
               verbose=False, env_name='Pendulum-v0'):
    '''Rolls out a gym environment and records frames plus canonical coordinates.

    For each of `trials` trials the environment is reset until the absolute
    initial angle (see `get_theta`) lies strictly between `min_angle` and
    `max_angle`; it is then stepped `timesteps` times with the action `[0.]`.
    At every step the frame is rendered *before* stepping, while the
    coordinates are read from the observation returned by that step.

    Returns a tuple `(canonical_coords, frames, gym_settings)`:
      * canonical_coords: array with `trials*timesteps` rows of
        `[theta, 0.25 * dtheta]`
      * frames: array with `trials*timesteps` rows, each a flattened
        `preproc` output
      * gym_settings: dict of this function's arguments

    The environment is always closed before returning, even on error.
    '''
    # must stay the first statement: snapshots only the call arguments
    gym_settings = locals()
    if verbose:
        print("Making a dataset of pendulum pixel observations.")
        print("Edit 5/20/19: you may have to rewrite the `preproc` function depending on your screen size.")
    env = gym.make(env_name)
    env.reset()
    env.seed(seed)

    canonical_coords, frames = [], []
    try:
        for _ in range(trials):
            # rejection-sample an initial state whose angle is in range
            while True:
                obs = env.reset()
                theta_init = np.abs(get_theta(obs))
                if verbose:
                    print("\tCalled reset. Max angle= {:.3f}".format(theta_init))
                if theta_init > min_angle and theta_init < max_angle:
                    break

            if verbose:
                print("\tRunning environment...")

            for _ in range(timesteps):
                frames.append(preproc(env.render('rgb_array'), side))
                obs, _, _, _ = env.step([0.])
                theta, dtheta = get_theta(obs), obs[-1]

                # The constant factor of 0.25 comes from writing H = PE + KE*c
                # and choosing c such that the total energy is as close to
                # constant as possible. It's not perfect, but it is the best
                # we can do.
                canonical_coords.append(np.array([theta, 0.25 * dtheta]))
    finally:
        # the env is created locally and never handed out, so release it here
        env.close()

    canonical_coords = np.stack(canonical_coords).reshape(trials*timesteps, -1)
    frames = np.stack(frames).reshape(trials*timesteps, -1)
    return canonical_coords, frames, gym_settings


def make_gym_dataset(test_split=0.2, **kwargs):
    '''Builds a train/test dataset from `sample_gym` rollouts.

    `kwargs` are forwarded to `sample_gym`. Per trial, finite differences of
    the canonical coordinates and of stacked adjacent frames are computed,
    together with the same pixel quantities one timestep ahead. The first
    `test_split` fraction of the concatenated rows becomes the `test_*`
    entries, the remainder the training entries. Settings are stored under
    `'meta'` with `timesteps` reduced by the 3 samples lost to offsets.
    '''
    canonical_coords, frames, gym_settings = sample_gym(**kwargs)
    n_trials = gym_settings['trials']

    field_names = ('coords', 'dcoords', 'pixels', 'dpixels',
                   'next_pixels', 'next_dpixels')
    per_trial = {name: [] for name in field_names}

    coord_chunks = np.split(canonical_coords, n_trials)
    frame_chunks = np.split(frames, n_trials)
    for trial_coords, trial_frames in zip(coord_chunks, frame_chunks):
        # canonical coordinates: drop the first sample, then difference
        shifted = trial_coords[1:]
        dcc = shifted[1:] - shifted[:-1]
        cc = shifted[1:]

        # adjacent frames are concatenated so a pixel sample carries velocity
        # information, i.e. the same information as the canonical coordinates
        # but in a different (highly nonlinear) basis
        stacked = np.concatenate([trial_frames[:-1], trial_frames[1:]], axis=-1)
        dp = stacked[1:] - stacked[:-1]
        p = stacked[1:]

        # current-step quantities drop the last sample, next-step ones the first
        per_trial['coords'].append(cc[:-1])
        per_trial['dcoords'].append(dcc[:-1])
        per_trial['pixels'].append(p[:-1])
        per_trial['dpixels'].append(dp[:-1])
        per_trial['next_pixels'].append(p[1:])
        per_trial['next_dpixels'].append(dp[1:])

    # join all trials into one array per field
    merged = {name: np.concatenate(chunks) for name, chunks in per_trial.items()}

    # leading rows form the test set, the rest the training set
    split_ix = int(merged['coords'].shape[0] * test_split)
    data = {}
    for name, values in merged.items():
        data[name] = values[split_ix:]
        data['test_' + name] = values[:split_ix]

    gym_settings['timesteps'] -= 3  # accounts for the offsets taken above
    data['meta'] = gym_settings

    return data


def get_dataset(experiment_name, save_dir, **kwargs):
    '''Returns a dataset built on top of OpenAI Gym observations (or one of the
    other supported sources), constructing and caching it when no saved
    version can be loaded.

    experiment_name: one of 'pendulum', 'acrobot', '3body', 'pend-sim',
        'pend-real'; anything else raises AssertionError.
    save_dir: directory holding `<experiment_name>-pixels-dataset.pkl`.
    kwargs: forwarded to `make_gym_dataset` / `make_orbits_dataset` when the
        dataset has to be rebuilt.
    '''
    known_experiments = ('pendulum', 'acrobot', '3body', 'pend-sim', 'pend-real')
    if experiment_name not in known_experiments:
        # Raised explicitly (instead of `assert`) so the check survives
        # `python -O`; AssertionError is kept for backward compatibility.
        raise AssertionError(
            "unknown experiment_name {!r}; expected one of {}".format(
                experiment_name, known_experiments))

    path = '{}/{}-pixels-dataset.pkl'.format(save_dir, experiment_name)

    try:
        data = from_pickle(path)
    except Exception:
        # Any loading failure triggers a rebuild, but KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        print("Had a problem loading data from {}. Rebuilding dataset...".format(path))
        if experiment_name == 'pendulum':
            data = make_gym_dataset(**kwargs)
        elif experiment_name in ['pend-sim', 'pend-real']:
            data = make_pendulum_dataset(experiment_name, save_dir)
        else:
            # NOTE(review): 'acrobot' also ends up here and is rebuilt with
            # the orbits generator -- confirm whether that is intended.
            data = make_orbits_dataset(**kwargs)
        to_pickle(data, path)
    else:
        print("Successfully loaded data from {}".format(path))

    return data


def make_pendulum_dataset(experiment_name, save_dir, test_split=0.8):
    '''Downloads the pendulum archive and turns one of its files into a dataset.

    experiment_name: 'pend-sim' selects 'pendulum_h_1'; any other value
        selects 'real_pend_h_1'.
    save_dir: directory the archive is downloaded to (created if missing).
    test_split: despite its name, this is used as the fraction of samples
        kept for *training*; `1 - test_split` goes to the test set.

    Returns a dict with one column per parsed name plus 'x' (columns 2:4 of
    the parsed table) and its finite-difference time derivative 'dx', each
    alongside a 'test_'-prefixed counterpart.
    '''
    dataset_name = 'pendulum_h_1' if experiment_name == 'pend-sim' else 'real_pend_h_1'

    url = 'http://science.sciencemag.org/highwire/filestream/590089/field_highwire_adjunct_files/2/'
    # exist_ok avoids the check-then-create race of testing os.path.exists first
    os.makedirs(save_dir, exist_ok=True)
    out_file = '{}/invar_datasets.zip'.format(save_dir)

    urlretrieve(url, out_file)

    data_str = read_lipson(dataset_name, save_dir)
    state, names = str2array(data_str)

    # put data in a dictionary structure (one column vector per name)
    data = {k: state[:, i:i + 1] for i, k in enumerate(names)}
    data['x'] = state[:, 2:4]
    data['dx'] = (data['x'][1:] - data['x'][:-1]) / (data['t'][1:] - data['t'][:-1])
    data['x'] = data['x'][:-1]  # drop the last row so 'x' lines up with 'dx'

    # Make a train/test split while preserving the order of the data.
    # There's no great way to do this: the test block is the contiguous run
    # [a, b) that starts right after the first `train_set_size` samples;
    # everything outside that run is training data.
    n_samples = len(data['x'])
    train_set_size = int(n_samples * test_split)
    test_set_size = int(n_samples * (1 - test_split))
    a = train_set_size
    b = a + test_set_size

    split_data = {}
    for k, v in data.items():
        split_data[k] = np.concatenate([v[:a], v[b:]], axis=0)
        split_data['test_' + k] = v[a:b]
    return split_data


### FOR DYNAMICS IN ANALYSIS SECTION ###
def hamiltonian_fn(coords):
    '''Pendulum Hamiltonian H = k * (1 - cos(q)) + p**2.

    `coords` is split into two equal halves, `q` followed by `p`.
    '''
    k = 1.9  # this coefficient must be fit to the data
    q, p = np.split(coords, 2)
    potential = k * (1 - np.cos(q))
    kinetic = p**2
    return potential + kinetic


def dynamics_fn(t, coords):
    '''Time derivative of `coords = [q, p]` under the pendulum Hamiltonian
    H = k * (1 - cos(q)) + p**2.

    The gradient is written in closed form (dH/dq = k*sin(q), dH/dp = 2*p)
    because `autograd`, which the previous implementation relied on, is not
    imported by this module. The sign convention of the previous expression
    `-concatenate([dH/dp, -dH/dq])` is kept, so the result is
    `[-2*p, k*sin(q)]`. `t` is unused; it is only there to match the
    `fun(t, y)` signature.
    '''
    k = 1.9  # must match the coefficient used in `hamiltonian_fn`
    q, p = np.split(coords, 2)
    dH_dq = k * np.sin(q)
    dH_dp = 2 * p
    return -np.concatenate([dH_dp, -dH_dq], axis=-1)


##### ENERGY #####
def potential_energy(state):
    r'''U = -\sum_{i, j>i} m_i m_j / r_ij  (with G = 1).

    `state` is indexed as [body, property, time], where property 0 is the
    mass and properties 1:3 are the position. Returns the energy per time
    index with singleton axes squeezed away.
    '''
    n_bodies = state.shape[0]
    pair_sum = np.zeros((1, 1, state.shape[2]))
    for i in range(n_bodies):
        m_i = state[i:i + 1, 0:1]
        pos_i = state[i:i + 1, 1:3]
        for j in range(i + 1, n_bodies):
            m_j = state[j:j + 1, 0:1]
            # Euclidean distance between bodies i and j at every time index
            r_ij = ((pos_i - state[j:j + 1, 1:3]) ** 2).sum(1, keepdims=True) ** .5
            pair_sum += m_i * m_j / r_ij
    return -pair_sum.sum(0).squeeze()


def kinetic_energy(state):
    r'''T = \sum_i .5 * m_i * v_i^2.

    `state` is indexed as [body, property, ...], where property 0 is the mass
    and properties 3:5 are the velocity. The sum runs over bodies and
    singleton axes are squeezed from the result.
    '''
    masses = state[:, 0:1]
    speed_sq = (state[:, 3:5] ** 2).sum(1, keepdims=True)
    per_body = .5 * masses * speed_sq
    return per_body.sum(0).squeeze()


def total_energy(state):
    '''Sum of `potential_energy(state)` and `kinetic_energy(state)`.'''
    potential = potential_energy(state)
    kinetic = kinetic_energy(state)
    return potential + kinetic


##### DYNAMICS #####
def get_accelerations(state, epsilon=0):
    '''Net gravitational acceleration (G=1) on every body.

    `state` is [bodies x properties] with column 0 the mass and columns 1:3
    the position. `epsilon` is added to the cubed distance in the
    denominator. Returns an array of shape [bodies x 2].
    '''
    per_body = []
    for idx in range(state.shape[0]):
        # every row except the body currently being evaluated
        others = np.concatenate([state[:idx, :], state[idx + 1:, :]], axis=0)
        offsets = others[:, 1:3] - state[idx, 1:3]
        dist = (offsets ** 2).sum(1, keepdims=True) ** 0.5
        other_masses = others[:, 0:1]
        pulls = other_masses * offsets / (dist ** 3 + epsilon)
        per_body.append(pulls.sum(0, keepdims=True))
    return np.concatenate(per_body, axis=0)


def update(t, state):
    '''Derivative of the flattened n-body state, in `fun(t, y)` form.

    The flat `state` is viewed as [bodies, 5]; columns 1:3 of the derivative
    take the velocities (columns 3:5), columns 3:5 take the accelerations and
    the mass column stays zero. `t` is unused.
    '''
    bodies = state.reshape(-1, 5)  # [bodies, properties]
    deriv = np.zeros_like(bodies)
    deriv[:, 3:5] = get_accelerations(bodies)  # dvx, dvy = ax, ay
    deriv[:, 1:3] = bodies[:, 3:5]  # dx, dy = vx, vy
    return deriv.reshape(-1)


##### INTEGRATION SETTINGS #####
def get_orbit(state, update_fn=update, t_points=100, t_span=None, nbodies=3, **kwargs):
    '''Integrates `update_fn` from `state` and returns the sampled trajectory.

    `t_span` defaults to [0, 2] and `rtol` to 1e-9; remaining `kwargs` go to
    `solve_ivp`. The solution is sampled at `t_points` evenly spaced times.
    The number of bodies is taken from `state.shape[0]`; the `nbodies`
    argument is only recorded in the settings.

    Returns `(orbit, orbit_settings)` where `orbit` has shape
    [bodies, 5, t_points] and `orbit_settings` holds the arguments plus
    `t_eval`.
    '''
    t_span = [0, 2] if t_span is None else t_span
    kwargs.setdefault('rtol', 1e-9)

    # snapshot of the (defaulted) arguments; taken before any other local exists
    orbit_settings = locals()

    t_eval = np.linspace(t_span[0], t_span[1], t_points)
    orbit_settings['t_eval'] = t_eval

    solution = solve_ivp(fun=update_fn, t_span=t_span, y0=state.flatten(),
                         t_eval=t_eval, **kwargs)
    body_count = state.shape[0]
    orbit = solution['y'].reshape(body_count, 5, t_points)
    return orbit, orbit_settings


##### INITIALIZE THE THREE BODIES #####
def rotate2d(p, theta):
    '''Rotates the 2-vector `p` counter-clockwise by `theta` radians.'''
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])
    column = p.reshape(2, 1)
    rotated = rotation @ column
    return rotated.squeeze()


def random_config(nu=2e-1, min_radius=0.9, max_radius=1.2):
    '''Draws a random, nearly circular three-body initial state (heuristic).

    Three unit-mass bodies are placed 120 degrees apart at a radius drawn
    uniformly from [min_radius, max_radius), given the velocities of a
    circular orbit, and each velocity component is then rescaled by a random
    factor in [1 - nu, 1 + nu). Returns a (3, 5) array whose columns are
    [mass, x, y, vx, vy].
    '''
    third_turn = 2 * np.pi / 3

    # random direction first, then the radius (keeps the RNG draw order)
    first_pos = 2 * np.random.rand(2) - 1
    r = np.random.rand() * (max_radius - min_radius) + min_radius
    first_pos *= r / np.sqrt(np.sum((first_pos ** 2)))

    positions = [first_pos]
    for _ in range(2):
        positions.append(rotate2d(positions[-1], theta=third_turn))

    # velocity that yields a circular orbit
    first_vel = rotate2d(first_pos, theta=np.pi / 2)
    first_vel = first_vel / r ** 1.5
    # scale factor to get circular trajectories
    first_vel = first_vel * np.sqrt(np.sin(np.pi / 3) / (2 * np.cos(np.pi / 6) ** 2))

    velocities = [first_vel]
    for _ in range(2):
        velocities.append(rotate2d(velocities[-1], theta=third_turn))

    # perturb each velocity so the circular orbits become slightly chaotic
    velocities = [vel * (1 + nu * (2 * np.random.rand(2) - 1)) for vel in velocities]

    state = np.zeros((3, 5))
    state[:, 0] = 1
    for body, (pos, vel) in enumerate(zip(positions, velocities)):
        state[body, 1:3] = pos
        state[body, 3:5] = vel
    return state


##### INTEGRATE AN ORBIT OR TWO #####
def sample_orbits(timesteps=20, trials=5000, nbodies=3, orbit_noise=2e-1,
                  min_radius=0.9, max_radius=1.2, t_span=None, verbose=False, **kwargs):
    '''Integrates random near-circular orbits until `timesteps * trials`
    samples have been collected.

    `t_span` defaults to [0, 5]; extra `kwargs` are forwarded to `get_orbit`.
    Returns `(data, orbit_settings)` where `data` holds 'coords', 'dcoords'
    and 'energy' arrays with one row per sample and `orbit_settings` is a
    dict of the arguments.
    '''
    t_span = [0, 5] if t_span is None else t_span
    # snapshot of the arguments; taken before any other local exists
    orbit_settings = locals()
    if verbose:
        print("Making a dataset of near-circular 3-body orbits:")

    coord_rows, dcoord_rows, energies = [], [], []
    n_samples = timesteps * trials
    while len(coord_rows) < n_samples:
        initial = random_config(nu=orbit_noise, min_radius=min_radius, max_radius=max_radius)
        orbit, _ = get_orbit(initial, t_points=timesteps, t_span=t_span, nbodies=nbodies, **kwargs)
        # one flattened [nbodies * 5] row per sampled time
        snapshots = orbit.transpose(2, 0, 1).reshape(-1, nbodies * 5)

        for flat_state in snapshots:
            flat_deriv = update(None, flat_state)

            # go from [nbodies, (m, qx, qy, px, py)] to a flat vector grouped
            # by property, with the mass row dropped
            coord_rows.append(flat_state.reshape(nbodies, 5).T[1:].flatten())
            dcoord_rows.append(flat_deriv.reshape(nbodies, 5).T[1:].flatten())

            # total_energy expects a trailing time axis
            energies.append(total_energy(flat_state.copy().reshape(nbodies, 5, 1)))

    data = {'coords': np.stack(coord_rows)[:n_samples],
            'dcoords': np.stack(dcoord_rows)[:n_samples],
            'energy': np.stack(energies)[:n_samples]}
    return data, orbit_settings


##### MAKE A DATASET #####
def make_orbits_dataset(test_split=0.2, **kwargs):
    '''Samples orbits via `sample_orbits(**kwargs)` and splits them.

    The first `test_split` fraction of the rows becomes the `test_*` entries
    and the rest the training entries; the sampling settings are stored under
    `'meta'`.
    '''
    samples, orbit_settings = sample_orbits(**kwargs)

    n_test = int(samples['coords'].shape[0] * test_split)
    data = {}
    for name, values in samples.items():
        data[name] = values[n_test:]
        data['test_' + name] = values[:n_test]

    data['meta'] = orbit_settings
    return data


def read_lipson(experiment_name, save_dir):
    '''Reads `<experiment_name>.txt` from `<save_dir>/invar_datasets.zip`.

    Returns `str()` of the raw bytes, i.e. the bytes *repr* (leading `b'`,
    escaped newlines); `str2array` parses exactly that form.

    Raises FileNotFoundError if the archive has no member with that name.
    '''
    desired_file = experiment_name + ".txt"
    archive_path = '{}/invar_datasets.zip'.format(save_dir)
    with zipfile.ZipFile(archive_path) as z:
        try:
            data = z.read(desired_file)
        except KeyError:
            # ZipFile.read signals a missing member with KeyError
            raise FileNotFoundError(
                "'{}' not found in archive '{}'".format(desired_file, archive_path)
            ) from None
    return str(data)


def str2array(string):
    '''Parses the text produced by `read_lipson` into an array and its column names.

    The input is split on literal backslash-n sequences. The first piece is
    the header holding the state names; every piece except the first and the
    last is a row of whitespace-separated numbers. Returns `(data, names)`
    where `names` is ['trial', 't'], the state names, then the state names
    prefixed with 'd'.
    '''
    pieces = string.split('\\n')
    header, rows = pieces[0], pieces[1:-1]

    state_names = header.strip("b'% \\r").split(' ')
    names = ['trial', 't'] + state_names + ['d' + n for n in state_names]

    data = []
    for row in rows:
        tokens = row.strip("' \\r,").split()
        data.append([float(tok) for tok in tokens])

    return np.asarray(data), names
